"""
Phase 4: Thresholds & Nonlinearities
======================================
1. OADR threshold test (spline regression with candidate knots)
2. Life expectancy threshold (connect to Project 1 LE turning point)
3. Speed of aging (Δold_dep as predictor)
4. Income interaction (Z × log GDP/cap — "getting old before getting rich")
5. Time-varying relationship (15-year rolling windows, pre/post-GFC sample split)

Input:  japanification/data/processed/japan_panel_indexed.csv
Output: japanification/output/tables/phase4_*.csv
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Absolute project root.
# NOTE(review): hard-coded path (looks like a WSL mount of the C: drive);
# adjust when running on another machine.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
JAPAN_DIR = PROJECT_DIR / "japanification"
# Input: japan_panel_indexed.csv is read from here.
PROCESSED_DIR = JAPAN_DIR / "data" / "processed"
# Output: all phase4_*.csv tables are written here.
TABLE_DIR = JAPAN_DIR / "output" / "tables"

# Put the shared multilateral source directory first on the import path so the
# `from model import PanelGLS` below presumably resolves to
# multilateral/src/model.py rather than any other `model` module.
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Estimate one PanelGLS regression, print its summary, and tabulate it.

    Parameters
    ----------
    y, X : array-like
        Dependent variable and regressor matrix, passed unchanged to
        ``PanelGLS.fit``.
    entity_ids, time_ids : array-like
        Per-row panel identifiers (this script passes country ISO3 and year).
    feature_names : list of str
        Labels for the columns of ``X``; used by both the printed summary and
        the returned table.
    label : str
        Model name, printed as a heading and stored in a ``model`` column.

    Returns
    -------
    tuple
        ``(fitted PanelGLS instance, DataFrame from to_dataframe tagged with label)``.
    """
    gls = PanelGLS()
    gls.fit(y, X, entity_ids, time_ids)

    # The heading is emitted after fitting, right before the summary table.
    print(f"\n  {label}")
    gls.summary(feature_names=feature_names)

    coef_table = gls.to_dataframe(feature_names=feature_names)
    coef_table['model'] = label
    return gls, coef_table


def _print_section(title):
    """Print the banner that opens each numbered analysis section."""
    print(f"\n{'=' * 70}")
    print(f"  {title}")
    print("=" * 70)


def _fit_complete_cases(data, dep_var, regressors, label, min_obs=200):
    """Fit ``dep_var`` on ``regressors`` over complete rows via fit_and_report.

    Rows missing the dependent variable or any regressor are dropped first.
    Returns ``(model, result_df)``, or ``None`` without printing anything when
    fewer than ``min_obs`` complete rows remain.
    """
    est = data.dropna(subset=[dep_var] + regressors)
    if len(est) < min_obs:
        return None
    return fit_and_report(
        est[dep_var].values, est[regressors].values,
        est['iso3'].values, est['year'].values,
        regressors, label,
    )


def _oadr_threshold_test(df, dep_var, controls, knots=(0.15, 0.20, 0.25, 0.30)):
    """Section 1: linear-spline OADR regressions at each candidate knot.

    Adds ``oadr_below_<k>`` / ``oadr_above_<k>`` columns to ``df`` in place,
    writes ``phase4_oadr_thresholds.csv`` and prints the knot with the highest
    R². These fits are tabulated here only; nothing is returned.
    """
    _print_section("1. OADR Threshold Test")

    threshold_rows = []
    for knot in knots:
        # round() rather than int(): e.g. int(0.29 * 100) == 28 due to float error.
        tag = round(knot * 100)
        below, above = f'oadr_below_{tag}', f'oadr_above_{tag}'
        # Linear spline: below = min(old_dep, knot), above = max(0, old_dep - knot).
        # The slope may change at the knot while the fitted line stays continuous.
        df[below] = df['old_dep'].clip(upper=knot)
        df[above] = (df['old_dep'] - knot).clip(lower=0)

        spline_vars = [below, above] + controls
        # clip() keeps NaN, so the estimation sample is the same for every knot,
        # which keeps the R² values comparable across knots.
        est = df.dropna(subset=[dep_var] + spline_vars)
        if len(est) < 200:
            continue

        # If no observation lies on one side of the knot, that segment is a
        # constant column (all zeros above, or all == knot below). Its slope is
        # not identified and the design matrix is rank-deficient, so skip it.
        if est[below].nunique() < 2 or est[above].nunique() < 2:
            print(f"  Knot={knot:.0%}: skipped (no OADR variation on one side of the knot)")
            continue

        model = PanelGLS()
        model.fit(est[dep_var].values, est[spline_vars].values,
                  est['iso3'].values, est['year'].values)

        # Spline terms are the first two regressors.
        below_coef, above_coef = model.beta[0], model.beta[1]
        below_p, above_p = model.pvalues[0], model.pvalues[1]

        threshold_rows.append({
            'knot': knot,
            'coef_below': below_coef,
            'p_below': below_p,
            'coef_above': above_coef,
            'p_above': above_p,
            'R_squared': model.r_squared,
            'N': model.n_obs,
            'above_minus_below': above_coef - below_coef,
        })

        sig_above = '***' if above_p < 0.01 else '**' if above_p < 0.05 else '*' if above_p < 0.1 else ''
        print(f"  Knot={knot:.0%}: below={below_coef:.3f} (p={below_p:.3f}), "
              f"above={above_coef:.3f}{sig_above} (p={above_p:.3f}), "
              f"R²={model.r_squared:.4f}")

    threshold_df = pd.DataFrame(threshold_rows)
    threshold_df.to_csv(TABLE_DIR / "phase4_oadr_thresholds.csv", index=False)

    # Best threshold (highest R²)
    if len(threshold_df) > 0:
        best = threshold_df.loc[threshold_df['R_squared'].idxmax()]
        print(f"\n  Best knot: {best['knot']:.0%} (R²={best['R_squared']:.4f})")


def _life_expectancy_threshold(df, dep_var, controls):
    """Section 2: quadratic in life expectancy, alone and alongside Z_1..Z_3.

    Adds ``le_sq`` to ``df`` in place and prints the turning point -β₁/(2β₂)
    of the LE-only model. Returns the result frames of the fitted models
    (empty when ``life_expectancy`` is not a column of ``df``).
    """
    _print_section("2. Life Expectancy Threshold")
    results = []
    if 'life_expectancy' not in df.columns:
        return results

    df['le_sq'] = df['life_expectancy'] ** 2
    le_vars = ['life_expectancy', 'le_sq'] + controls
    fit = _fit_complete_cases(df, dep_var, le_vars, "LE Quadratic → Japanification")
    if fit is not None:
        m_le, r_le = fit
        results.append(r_le)

        # Vertex of β₁·LE + β₂·LE²; undefined when the quadratic term is zero.
        b_le = m_le.beta[le_vars.index('life_expectancy')]
        b_le_sq = m_le.beta[le_vars.index('le_sq')]
        if b_le_sq != 0:
            turning_point = -b_le / (2 * b_le_sq)
            print(f"\n  LE turning point: {turning_point:.1f} years")
            print("  (Compare to Project 1 CA/GDP turning point: ~65 years)")

    # Also test with Z polynomials AND LE
    combo_vars = ['Z_1', 'Z_2', 'Z_3', 'life_expectancy', 'le_sq'] + controls
    fit = _fit_complete_cases(df, dep_var, combo_vars, "Z + LE Quadratic → Japanification")
    if fit is not None:
        results.append(fit[1])
    return results


def _speed_of_aging(df, dep_var, controls):
    """Section 3: Δold_dep as a predictor, with old_dep and then with Z_1..Z_3.

    Returns the result frames of the fitted models (empty when
    ``delta_old_dep`` is not a column of ``df``).
    """
    _print_section("3. Speed of Aging")
    results = []
    if 'delta_old_dep' not in df.columns:
        return results

    specs = [
        (['old_dep', 'delta_old_dep'] + controls, "Speed of Aging (Δold_dep)"),
        (['Z_1', 'Z_2', 'Z_3', 'delta_old_dep'] + controls, "Z + Speed of Aging"),
    ]
    for regressors, label in specs:
        fit = _fit_complete_cases(df, dep_var, regressors, label)
        if fit is not None:
            results.append(fit[1])
    return results


def _income_interaction(df, dep_var, controls):
    """Section 4: Z × log(GDP/cap) interactions and a high/low income split.

    Adds ``log_gdp_pc`` and ``Z_k_x_lgdppc`` columns to ``df`` in place.
    Returns the result frames of the fitted models (empty when ``gdp_pc_ppp``
    is not a column of ``df``).
    """
    _print_section("4. Income Interaction")
    results = []
    if 'gdp_pc_ppp' not in df.columns:
        return results

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    # Values below 100 (including zero or negative) are floored before the log
    # so they cannot yield -inf/NaN; missing values stay NaN.
    df['log_gdp_pc'] = np.log(df['gdp_pc_ppp'].clip(lower=100))

    income_int = []
    for zv in z_vars:
        col = f'{zv}_x_lgdppc'
        df[col] = df[zv] * df['log_gdp_pc']
        income_int.append(col)

    income_vars = z_vars + ['log_gdp_pc'] + income_int + controls
    fit = _fit_complete_cases(df, dep_var, income_vars, "Income Interaction (Z × log GDP/cap)")
    if fit is not None:
        results.append(fit[1])

    # Split at the median of gdp_pc_ppp over all country-year rows of df, so a
    # country can move between groups over time. Rows with missing GDP satisfy
    # neither mask and are excluded from both subsamples.
    median_gdp = df['gdp_pc_ppp'].median()
    split_vars = z_vars + controls
    for label, mask in [('High income', df['gdp_pc_ppp'] >= median_gdp),
                        ('Low income', df['gdp_pc_ppp'] < median_gdp)]:
        fit = _fit_complete_cases(df[mask], dep_var, split_vars,
                                  f"Income Split: {label}", min_obs=100)
        if fit is not None:
            results.append(fit[1])
    return results


def _time_varying(df, dep_var, controls, first_start=1990, last_start=2010, window_years=15):
    """Section 5: rolling-window fits of Z_1..Z_3 plus a pre/post-GFC split.

    Writes ``phase4_rolling_windows.csv`` when at least one window has >= 100
    complete rows, and returns the result frames of the two GFC-split fits.
    """
    _print_section("5. Time-Varying Relationship")
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    base_vars = demo_vars + controls

    window_rows = []
    for start_year in range(first_start, last_start + 1):
        end_year = start_year + window_years - 1  # inclusive window end
        win = df[df['year'].between(start_year, end_year)]
        win_est = win.dropna(subset=[dep_var] + base_vars)
        if len(win_est) < 100:
            continue

        model = PanelGLS()
        model.fit(win_est[dep_var].values, win_est[base_vars].values,
                  win_est['iso3'].values, win_est['year'].values)

        row = {
            'window_start': start_year,
            'window_end': end_year,
            'N': model.n_obs,
            'N_countries': model.n_countries,
            'R_squared': model.r_squared,
        }
        # Z terms are the leading regressors in base_vars.
        for i, zv in enumerate(demo_vars):
            row[f'{zv}_coef'] = model.beta[i]
            row[f'{zv}_pval'] = model.pvalues[i]
        window_rows.append(row)

    window_df = pd.DataFrame(window_rows)
    if len(window_df) > 0:
        window_df.to_csv(TABLE_DIR / "phase4_rolling_windows.csv", index=False)
        print("\n  Rolling Window Results:")
        print(window_df[['window_start', 'window_end', 'N', 'R_squared',
                         'Z_1_coef', 'Z_1_pval']].to_string(index=False, float_format='%.4f'))

    # Structural break test: pre-GFC vs post-GFC.
    # Both samples are bounded to the years named in their labels; the pre-GFC
    # sample starts with the first rolling window, and 2008 is in neither.
    print("\n  --- Structural Break: Pre/Post GFC ---")
    results = []
    for label, (lo, hi) in [('Pre-GFC (1990-2007)', (1990, 2007)),
                            ('Post-GFC (2009-2024)', (2009, 2024))]:
        fit = _fit_complete_cases(df[df['year'].between(lo, hi)], dep_var, base_vars,
                                  f"Break: {label}", min_obs=100)
        if fit is not None:
            results.append(fit[1])
    return results


def main():
    """Run all Phase 4 threshold/nonlinearity analyses and save the results.

    Reads ``japan_panel_indexed.csv`` from PROCESSED_DIR, runs sections 1-5,
    and writes ``phase4_*.csv`` tables to TABLE_DIR, creating the directory if
    it is missing. Returns the list of per-model result DataFrames from
    sections 2-5, in the order they were fitted.
    """
    print("=" * 70)
    print("PHASE 4: Thresholds & Nonlinearities")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "japan_panel_indexed.csv")
    print(f"Loaded: {len(df):,} obs, {df['iso3'].nunique()} countries")

    # to_csv does not create missing parent directories.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)

    dep_var = 'japan_index_2c'
    controls = ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']

    # Sections run in order and may add derived columns to df in place.
    _oadr_threshold_test(df, dep_var, controls)
    all_results = []
    all_results += _life_expectancy_threshold(df, dep_var, controls)
    all_results += _speed_of_aging(df, dep_var, controls)
    all_results += _income_interaction(df, dep_var, controls)
    all_results += _time_varying(df, dep_var, controls)

    # Save all results
    if all_results:
        results_df = pd.concat(all_results, ignore_index=True)
        results_df.to_csv(TABLE_DIR / "phase4_regression_results.csv", index=False)
        print(f"\nSaved: {TABLE_DIR / 'phase4_regression_results.csv'}")
        print(f"  {results_df['model'].nunique()} models")

    return all_results


if __name__ == "__main__":
    # Keep the list of result DataFrames returned by main() bound at module
    # level so it can be inspected afterwards (e.g. when run with `python -i`).
    results = main()
